In [ ]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import GradientNet
from myargs import Args
from myutils import MyUtils

Configurations


In [ ]:
myutils = MyUtils()

args = Args()
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True


# growth_rate = (4*(2**(args.gpu_num)))
transition_scale=2
pretrained_scale=4
growth_rate = 32

#######
# args.test_scene = ['alley_2', 'bamboo_2', 'bandage_2', 'cave_4', 'market_5', 'mountain_1', 'shaman_3', 'sleeping_2', 'temple_3']
args.test_scene = 'bamboo_1'
gradient=False
args.gpu_num = 0
#######

# Gradient-domain runs get a distinct TensorBoard comment, and saved prediction images
# are shifted by `offset` before scaling by 255.
# NOTE(review): presumably this centres signed gradient maps for display — confirm.
writer_comment = '{}_{}'.format(args.test_scene, 'gd' if gradient else 'rgb')
offset = 0.5 if gradient else 0.

args.display_interval = 50
args.display_curindex = 0

system_ = platform.system()
# platform.dist() was removed in Python 3.8; fall back to an empty triple so the host
# checks below simply don't match instead of raising AttributeError.
dist_info = platform.dist() if hasattr(platform, 'dist') else ('', '', '')
system_dist, system_version, _ = dist_info
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif dist_info == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif dist_info == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True
else:
    print('WARNING: unrecognised host {} {}; args.train_dir / args.pretrained are not set here'.format(system_, dist_info))

# Use CUDA only on Linux hosts that actually have a usable GPU; otherwise
# torch.cuda.set_device would raise on a CPU-only Linux machine.
use_gpu = system_ == 'Linux' and torch.cuda.is_available()
if use_gpu:
    torch.cuda.set_device(args.gpu_num)

print(dist_info)

My DataLoader


In [ ]:
# Training samples use random_crop=True; test samples use a deterministic CenterCrop
# to the same (image_h, image_w) size with random_crop=False.
train_dataset = MyImageFolder(args.train_dir, 'train',
                              transforms.Compose([transforms.ToTensor()]),
                              random_crop=True,
                              img_extentions=args.img_extentions, test_scene=args.test_scene,
                              image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test',
                             transforms.Compose([transforms.CenterCrop((args.image_h, args.image_w)),
                                                 transforms.ToTensor()]),
                             random_crop=False,
                             img_extentions=args.img_extentions, test_scene=args.test_scene,
                             image_h=args.image_h, image_w=args.image_w)

train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
# Averaged test losses do not depend on order, but test_model saves the prediction for
# the first batch (ind == 0); a fixed order makes those snapshots comparable across epochs.
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First convolution (stride-2 convolution followed by stride-2 max pooling): 32M -> 16M -> 8M
    • Each transition layer halves the resolution: 8M -> 4M -> 2M (there is no transition after the last dense block)

In [ ]:
# Look up the torchvision constructor named by args.arch and build the backbone
# (torchvision loads pretrained weights when args.pretrained is True).
backbone_ctor = models.__dict__[args.arch]
densenet = backbone_ctor(pretrained=args.pretrained)

# Freeze every backbone weight so the requires_grad filter used when building the
# optimizer leaves them out.
for weight in densenet.parameters():
    weight.requires_grad = False

if use_gpu:
    densenet.cuda()

In [ ]:
# ss: schedule unit in epochs; s0: epoch at which the merge branch is first enabled
# (before training_thresholds[-1] == s0 the training loop runs with go_through_merge=False).
ss = 6
s0 = 5 * ss
# s0 = 2

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 240
# Start epoch per output slot: slots 0-4 from epoch 0, slot 5 from s0.
args.training_thresholds = [0] * 5 + [s0]
# Start epoch per merge level, indexed as the training loop reads them:
# [0] '16M', [1] '08M', [2] '04M', [3] '02M', [4] never consulted, [5] '32M'.
args.training_merge_thresholds = [s0 + 9 * ss, s0 + 6 * ss, s0 + 3 * ss, s0, -1, s0 + 12 * ss]
args.power = 0.5


net = GradientNet(densenet=densenet, growth_rate=growth_rate,
                  transition_scale=transition_scale, pretrained_scale=pretrained_scale,
                  gradient=gradient)
if use_gpu:
    net.cuda()


def _mse_per_slot(n_slots=6):
    """Return a list of `n_slots` references to one MSELoss (moved to the GPU when use_gpu)."""
    criterion = nn.MSELoss()
    if use_gpu:
        criterion = criterion.cuda()
    return [criterion] * n_slots


mse_losses = _mse_per_slot()
test_losses = _mse_per_slot()
mse_merge_losses = _mse_per_slot()
test_merge_losses = _mse_per_slot()

In [ ]:
# First output slot enabled by each merge level (slots 4 and 5 are handled separately).
_MERGE_FIRST_SLOT = {'16M': 0, '08M': 1, '04M': 2, '02M': 3}


def _merge_active(go_through_merge, i):
    """Return True when merged output slot `i` is scored at merge level `go_through_merge`.

    Slot 4 never has a merged loss, slot 5 only at '32M'; otherwise a level enables
    every slot from _MERGE_FIRST_SLOT[level] upwards.
    """
    if go_through_merge is False or i == 4:
        return False
    if go_through_merge == '32M':
        return True
    if i == 5:
        return False
    first = _MERGE_FIRST_SLOT.get(go_through_merge)
    return first is not None and i >= first


def _loss_value(loss):
    """Return a single-element loss as a Python float.

    `.numpy().item()` works both for 1-element losses (old PyTorch) and 0-dim losses
    (PyTorch >= 0.4), where indexing with `.numpy()[0]` raises IndexError.
    """
    return loss.data.cpu().numpy().item()


def _save_snapshot(pred, path):
    """Write the first three channels of the first sample of `pred` to `path`.

    The global `offset` is added and the result scaled by 255; the channel order is
    reversed (RGB -> BGR) for cv2.imwrite.
    """
    img = pred[0].cpu().data.numpy().transpose(1, 2, 0)[:, :, 0:3]
    cv2.imwrite(path, (img[:, :, ::-1] + offset) * 255)


def test_model(epoch, go_through_merge=False, phase='train'):
    """Evaluate the global `net` on `test_loader` and log averaged MSE losses.

    For every output slot i whose args.training_thresholds entry has been reached, the
    albedo ground truth is resized with myutils.processGt and compared against the unary
    prediction (slots 0-4) and, when `go_through_merge` enables slot i, against the merged
    prediction. Predictions of the first test batch are written as PNGs under
    snapshot<gpu_num>/, and the averaged losses are logged to the global `writer`.

    Args:
        epoch: current epoch; TensorBoard step, part of snapshot file names, and compared
            against args.training_thresholds to skip slots that are not active yet.
        go_through_merge: False or one of '32M', '16M', '08M', '04M', '02M'; forwarded to
            `net` and used to decide which merged outputs are scored.
        phase: 'train' runs `net` in train() mode, anything else in eval() mode; also used
            to label the logged curves and snapshot files.
    """
    if phase == 'train':
        net.train()
    else:
        net.eval()

    n_slots = len(args.training_thresholds)
    # Tiny initial counts avoid division by zero for slots that never receive a loss.
    test_losses_trainphase = [0] * n_slots
    test_cnts_trainphase = [0.00001] * n_slots
    test_merge_losses_trainphase = [0] * n_slots
    test_merge_cnts_trainphase = [0.00001] * n_slots

    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # The ground truth is identical for every slot; convert it to numpy once.
        gt0 = gt_albedo.numpy()
        input_img = Variable(input_img)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)

        ft_test, merged_RGB = net(input_img, go_through_merge=go_through_merge)

        for i, v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]:
                continue
            s = 1 if i == 5 else 2 ** (i + 1)
            gt, display = myutils.processGt(gt0, scale_factor=s, gd=gradient, return_image=True)
            gt_mg, display_mg = myutils.processGt(gt0, scale_factor=s // 2, gd=gradient, return_image=True)
            if use_gpu:
                gt = gt.cuda()
                gt_mg = gt_mg.cuda()
            merge_active = _merge_active(go_through_merge, i)

            if i != 5:
                loss = mse_losses[i](ft_test[i], gt)
                test_losses_trainphase[i] += _loss_value(loss)
                test_cnts_trainphase[i] += 1

            if merge_active:
                # Slot 5 is compared with the scale_factor=1 target, the others with the s//2 one.
                gt2 = gt if i == 5 else gt_mg
                loss = mse_merge_losses[i](merged_RGB[i], gt2)
                test_merge_losses_trainphase[i] += _loss_value(loss)
                test_merge_cnts_trainphase[i] += 1

            if ind == 0:
                if i != 5:
                    _save_snapshot(v, 'snapshot{}/test-phase_{}-{}-{}.png'.format(args.gpu_num, phase, epoch, i))
                if merge_active:
                    _save_snapshot(merged_RGB[i], 'snapshot{}/test-mg-phase_{}-{}-{}.png'.format(args.gpu_num, phase, epoch, i))

    # Every curve (including 2M) carries the phase, so the train-phase and test-phase
    # evaluations are logged as separate TensorBoard series instead of colliding.
    scale_labels = ['16M', '8M', '4M', '2M', '1M', 'merged']
    for prefix, loss_sums, loss_cnts in (('test', test_losses_trainphase, test_cnts_trainphase),
                                         ('mg test', test_merge_losses_trainphase, test_merge_cnts_trainphase)):
        for label, loss_sum, cnt in zip(scale_labels, loss_sums, loss_cnts):
            writer.add_scalars('{} loss'.format(label),
                               {'{} {} phase {}'.format(prefix, label, phase): np.array([loss_sum / cnt])},
                               global_step=epoch)

In [ ]:
# training loop

writer = SummaryWriter(comment='-{}'.format(writer_comment))

# Snapshot images and checkpoints are written to snapshot<gpu_num>/. Neither cv2.imwrite
# nor torch.save creates missing directories, so make sure it exists before training.
snapshot_dir = 'snapshot{}'.format(args.gpu_num)
if not os.path.isdir(snapshot_dir):
    os.makedirs(snapshot_dir)

# Only parameters that still require gradients are optimised (the backbone was frozen).
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None, base_lr=args.base_lr):
    """Set the learning rate of every parameter group in `optimizer`.

    With `reset_lr`, every group's lr is set to exactly that value and `epoch`, `beg`,
    `end` and `base_lr` are ignored. Otherwise a polynomial decay is applied::

        lr = max(base_lr * ((end - epoch) / (end - beg)) ** args.power, 1e-8)

    Args:
        optimizer: torch.optim optimizer whose param_groups are updated in place.
        epoch: position inside the decay window [beg, end].
        beg, end: first and last epoch of the decay window. A zero-length window
            (end == beg) is treated as fully decayed instead of dividing by zero.
        reset_lr: optional fixed lr that overrides the schedule.
        base_lr: lr at epoch == beg. NOTE: the default is bound to args.base_lr when
            this cell runs, not when the function is called.
    """
    if reset_lr is not None:
        new_lr = reset_lr
    else:
        if end == beg:
            fraction = 0.0
        else:
            fraction = float(end - epoch) / (end - beg)
        # An epoch past `end` would make the base negative, and a negative number raised
        # to a fractional power is complex (Python 3) or a ValueError (Python 2).
        fraction = max(fraction, 0.0)
        new_lr = max(base_lr * fraction ** args.power, 1.0e-8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
        

for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

    # Learning-rate schedule (polynomial decay, see adjust_learning_rate):
    #   * before training_thresholds[-1]: one decay over [0, s0 - 1];
    #   * before training_merge_thresholds[-1]: the decay restarts every `ss` epochs;
    #   * afterwards: one decay from training_merge_thresholds[-1] to the last epoch.
    if epoch < args.training_thresholds[-1]:
        adjust_learning_rate(optimizer, epoch, beg=0, end=s0-1)
    elif epoch < args.training_merge_thresholds[-1]:
        adjust_learning_rate(optimizer, (epoch-s0) % ss, beg=0, end=ss-1, base_lr=args.base_lr)
    else:
        adjust_learning_rate(optimizer, epoch, beg=args.training_merge_thresholds[-1], end=args.epoches-1, base_lr=args.base_lr)

    # Merge level passed to the network: off at first, then enabling more merged
    # outputs as each merge threshold is reached.
    if epoch < args.training_thresholds[-1]: go_through_merge = False
    elif epoch >= args.training_merge_thresholds[5]: go_through_merge = '32M'
    elif epoch >= args.training_merge_thresholds[0]: go_through_merge = '16M'
    elif epoch >= args.training_merge_thresholds[1]: go_through_merge = '08M'
    elif epoch >= args.training_merge_thresholds[2]: go_through_merge = '04M'
    elif epoch >= args.training_merge_thresholds[3]: go_through_merge = '02M'

    n_slots = len(args.training_thresholds)
    # Tiny initial counts avoid division by zero for slots that never receive a loss.
    run_losses = [0] * n_slots
    run_cnts = [0.00001] * n_slots
    run_merge_losses = [0] * n_slots
    run_merge_cnts = [0.00001] * n_slots
    # Restart from base_lr at every epoch where a slot or a merge level switches on.
    if epoch in args.training_thresholds or epoch in args.training_merge_thresholds:
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)

    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
        # --- prepare training data ---
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # display_curindex only changes at the end of the iteration, so decide once.
        display_now = args.display_curindex % args.display_interval == 0
        if display_now:
            im = input_img[0, :, :, :].numpy().transpose(1, 2, 0)[:, :, ::-1] * 255
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)
        # The ground truth is only consumed as numpy by processGt; read it from the CPU
        # batch directly instead of round-tripping through the GPU for every slot.
        gt0 = gt_albedo.numpy()
        input_img = Variable(input_img)
        if use_gpu:
            input_img = input_img.cuda()

        optimizer.zero_grad()
        ft_predict, merged_RGB = net(input_img, go_through_merge=go_through_merge)

        for i, threshold in enumerate(args.training_thresholds):
            if epoch < threshold:
                continue

            # --- prepare resized ground truth for slot i ---
            s = 1 if i == 5 else 2 ** (i + 1)
            gt, display = myutils.processGt(gt0, scale_factor=s, gd=gradient, return_image=True)
            gt_mg, display_mg = myutils.processGt(gt0, scale_factor=s // 2, gd=gradient, return_image=True)
            if use_gpu:
                gt = gt.cuda()
                gt_mg = gt_mg.cuda()
            if display_now:
                display = display[:, :, 0:3]
                cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), display[:, :, ::-1] * 255)

            # Whether the merged output of slot i is trained at this merge level
            # (slot 4 never; slot 5 only at '32M').
            merge_active = (go_through_merge is not False and i != 4 and (
                go_through_merge == '32M' or
                (go_through_merge == '16M' and i != 5) or
                (go_through_merge == '08M' and i != 5 and i > 0) or
                (go_through_merge == '04M' and i != 5 and i > 1) or
                (go_through_merge == '02M' and i != 5 and i > 2)))

            # --- compute losses ---
            # Each term is back-propagated on its own, so the graph is retained for the
            # remaining terms; gradients accumulate until optimizer.step().
            # `.numpy().item()` works for both 1-element and 0-dim (PyTorch >= 0.4) losses.
            if i != 5:
                loss = mse_losses[i](ft_predict[i], gt)
                run_losses[i] += loss.data.cpu().numpy().item()
                loss.backward(retain_graph=True)
                run_cnts[i] += 1

            if merge_active:
                gt2 = gt if i == 5 else gt_mg
                loss = mse_merge_losses[i](merged_RGB[i], gt2)
                run_merge_losses[i] += loss.data.cpu().numpy().item()
                loss.backward(retain_graph=True)
                run_merge_cnts[i] += 1

            # --- save training images ---
            if display_now:
                if i != 5:
                    im = (ft_predict[i].cpu().data.numpy()[0].transpose((1, 2, 0)) + offset) * 255
                    im = im[:, :, 0:3]
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:, :, ::-1])
                if merge_active:
                    im = (merged_RGB[i].cpu().data.numpy()[0].transpose((1, 2, 0)) + offset) * 255
                    im = im[:, :, 0:3]
                    cv2.imwrite('snapshot{}/train-mg-{}-{}.png'.format(args.gpu_num, epoch, i), im[:, :, ::-1])
        optimizer.step()
        args.display_curindex += 1

    # --- every epoch: print mean losses (last slot is the merged output) ---
    loss_output = ''
    for i in range(n_slots - 1):
        loss_output += ' %2dM: %6f' % ((2 ** (4 - i)), (run_losses[i] / run_cnts[i]))
    loss_output += ' merged: %6f' % (run_losses[-1] / run_cnts[-1])
    print(loss_output)
    loss_output = ''
    for i in range(n_slots - 1):
        loss_output += ' mg %2dM: %6f' % ((2 ** (4 - i)), (run_merge_losses[i] / run_merge_cnts[i]))
    loss_output += 'mg merged: %6f' % (run_merge_losses[-1] / run_merge_cnts[-1])
    print(loss_output)

    # --- checkpoint every 10 epochs ---
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))

    # --- evaluate and log training curves every 5 epochs ---
    if (epoch+1) % 5 == 0:
        test_model(epoch, phase='train', go_through_merge=go_through_merge)
        test_model(epoch, phase='test', go_through_merge=go_through_merge)

        for label, loss_sum, cnt in zip(['16M', '8M', '4M', '2M', '1M', 'merged'], run_losses, run_cnts):
            writer.add_scalars('{} loss'.format(label), {'train {} '.format(label): np.array([loss_sum / cnt])}, global_step=epoch)

Visualize Graph


In [ ]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Blue nodes are the Variables that require grad, orange are Tensors saved for
    backward in torch.autograd.Function; all other nodes are labelled with the
    class name of their grad_fn.

    Args:
        var: output Variable whose `grad_fn` graph is drawn.
        params: optional dict of (name, Variable) used to label leaf nodes that
            require grad. Leaves not found in the dict get an empty name.

    Returns:
        A graphviz Digraph (SVG format).
    """
    if params is not None:
        # dict.values() is a view in Python 3 and cannot be indexed, so check the
        # values through iteration instead of params.values()[0].
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}
    else:
        param_map = {}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        # Depth-first walk over grad_fn.next_functions and saved_tensors; `seen`
        # prevents revisiting shared sub-graphs.
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map.get(id(u), '')
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [ ]:
# x = Variable(torch.zeros(1,3,256,256))
# y = net(x.cuda())
# g = make_dot(y[-1])

In [ ]:
# g.render('net-transition_scale_{}'.format(transition_scale))

In [ ]:


In [ ]: